{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from PIL import Image\n",
    "import clip\n",
    "import os.path as osp\n",
    "import os, sys\n",
    "import torchvision.utils as vutils\n",
    "sys.path.insert(0, '../')\n",
    "\n",
    "from lib.utils import load_model_weights,mkdir_p\n",
    "from models.GALIP import NetG, CLIP_TXT_ENCODER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device used for CLIP and the generator: 'cpu', or e.g. 'cuda:0'.\n",
    "device = 'cpu'\n",
    "# Name of the CLIP backbone, defined once and passed to clip.load below so the\n",
    "# constant and the model that is actually loaded cannot drift apart.\n",
    "CLIP_text = \"ViT-B/32\"\n",
    "clip_model, preprocess = clip.load(CLIP_text, device=device)\n",
    "# Put CLIP in eval mode; it is only used for inference in this notebook.\n",
    "clip_model = clip_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the loaded CLIP model as the text encoder and build the generator,\n",
    "# both placed on the selected device.\n",
    "text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)\n",
    "netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)\n",
    "\n",
    "# Deserialize the checkpoint onto the CPU, then hand its 'netG' entry to\n",
    "# load_model_weights (single-GPU/CPU layout, hence multi_gpus=False).\n",
    "path = '../saved_models/pretrained/pre_cc12m.pth'\n",
    "checkpoint = torch.load(path, map_location='cpu')\n",
    "netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of images generated per caption.\n",
    "batch_size = 8\n",
    "# One 100-dimensional standard-normal noise vector per image, moved to the target device.\n",
    "# NOTE(review): 100 presumably has to match the second NetG constructor argument - confirm.\n",
    "noise = torch.randn(batch_size, 100).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text prompts to render; the generation cell writes one image grid per entry.\n",
    "captions = [\n",
    "    'Line drawing illustration of a kawaii cute ghost.',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the directory the generation cell saves into ('../samples/<caption>.png').\n",
    "mkdir_p('../samples')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate from text\n",
    "out_dir = '../samples'\n",
    "# The grids are written into out_dir, so make sure that exact directory exists.\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for caption in captions:\n",
    "        tokenized_text = clip.tokenize([caption]).to(device)\n",
    "        # Only the sentence-level embedding is passed to the generator; word_emb is unused here.\n",
    "        sent_emb, word_emb = text_encoder(tokenized_text)\n",
    "        # Repeat the single caption embedding so each noise vector is paired with it.\n",
    "        sent_emb = sent_emb.repeat(batch_size, 1)\n",
    "        fake_imgs = netG(noise, sent_emb, eval=True).float()\n",
    "        # File name is the caption with spaces turned into dashes; '/' is replaced as well\n",
    "        # so a caption containing it cannot point into a nonexistent subfolder.\n",
    "        name = caption.replace(' ', '-').replace('/', '-')\n",
    "        # normalize=True with value_range=(-1, 1) rescales that range to [0, 1] before saving;\n",
    "        # the batch is laid out as a grid with 8 images per row.\n",
    "        vutils.save_image(fake_imgs, f'{out_dir}/{name}.png', nrow=8, value_range=(-1, 1), normalize=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dfgan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "849434eb86c3997df801551b732438d01b491303b69c29efcf332721ce6d8430"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
